package me.yricky.kaguya.compile.verilog

import me.yricky.kaguya.base.*
import me.yricky.kaguya.base.lu.*
import me.yricky.kaguya.compile.ICompiler


/**
 * Renders a [Module] description as Verilog source text.
 *
 * The output consists of the module header with its port list, declarations for the
 * named internal signals, one fragment per logic unit, and a closing `endmodule`.
 */
class KaguyaVerilogCompiler : ICompiler {
    companion object {
        /** Verilog keywords that signal names are checked against. */
        val keywords = setOf(
            "input","output","wire","reg","always","module","begin","end","endmodule","fork","join","for","initial"
        )
    }

    /**
     * Builds the complete Verilog text for [module].
     *
     * @param module the module whose ports, internal signals and logic units are emitted.
     * @return the generated source, ending with `endmodule` (no trailing newline).
     */
    override fun compile(module: Module): String {
        val sb = StringBuilder()
        sb.append("module ${module.name}(\n")

        val io = module.io
        val inputDecls = io.inputs.map { port ->
            assertNotKeyword(port.name)
            buildString {
                append("    input ")
                if (port.width > 1) { append("[${port.width - 1}:0] ") }
                append(port.name)
            }
        }
        val outputDecls = io.outputs.map { port ->
            assertNotKeyword(port.name)
            buildString {
                append("    output ")
                if (port is Reg) { append("reg ") }
                if (port.width > 1) { append("[${port.width - 1}:0] ") }
                append(port.name)
            }
        }
        // Joining by position (instead of comparing each port with `last()`) keeps the
        // separators correct even if two ports compare equal.
        val portDecls = inputDecls + outputDecls
        if (portDecls.isNotEmpty()) {
            sb.append(portDecls.joinToString(",\n")).append("\n")
        }
        sb.append(");\n")

        // Only named internal signals get a declaration.
        module.internalSignal.filterIsInstance<NamedSignal>().forEach { signal ->
            assertNotKeyword(signal.name)
            sb.append(if (signal is Reg) "    reg " else "    wire ")
            if (signal.width > 1) { sb.append("[${signal.width - 1}:0] ") }
            sb.append("${signal.name};\n")
        }

        module.logicUnits.forEach { sb.append(logicUnitToString(it)) }

        sb.append("endmodule")
        return sb.toString()
    }

    /**
     * Asserts that [name] is not one of [keywords].
     *
     * Uses `assert`, so the check only fires when JVM assertions are enabled.
     * The message reads: "Please do not name a signal after a keyword".
     */
    private fun assertNotKeyword(name: String) {
        assert(!keywords.contains(name)) {
            "请不要将信号命名为关键字:$name"
        }
    }

    /**
     * Renders a single logic unit.
     *
     * @return an `assign` statement, an `always` block, or a module instantiation;
     * an empty string for [ExpressionLogicUnit], which is only rendered inline by [expressionOf].
     */
    fun logicUnitToString(logicUnit: LogicUnit): String {
        return when (logicUnit) {
            is Assign -> "    assign ${logicUnit.output.name} = ${expressionOf(logicUnit.input).removeBracket()};\n"
            is ExpressionLogicUnit -> ""
            is CombUnit -> compileCombUnit(logicUnit)
            is SeqUnit -> compileSeqUnit(logicUnit)
            is ModuleUnit -> {
                val innerIo = logicUnit.innerModule.io
                // `!!`: a port of the inner module without a connection fails here with an NPE.
                val connections =
                    innerIo.inputs.map { "        .${it.name}(${logicUnit.inputs[it.name]!!.name})" } +
                    innerIo.outputs.map { "        .${it.name}(${logicUnit.outputs[it.name]!!.name})" }
                val sb = StringBuilder()
                sb.append("    ${logicUnit.innerModule.name} ${logicUnit.instName}(\n")
                if (connections.isNotEmpty()) {
                    sb.append(connections.joinToString(",\n")).append("\n")
                }
                sb.append("    );\n")
                sb.toString()
            }
        }
    }

    /**
     * Renders [lu] as an `always @(*)` block using blocking (`=`) assignments.
     *
     * Each logic block becomes either a plain assignment of its default (no cases) or a
     * single `if / else if / ... / else` chain whose final `else` assigns the default.
     */
    fun compileCombUnit(lu: CombUnit): String {
        val sb = StringBuilder()
        sb.append("    always @(*) begin\n")
        lu.logicBlocks.forEach { ab ->
            sb.appendGuardedAssign(
                ab.target.name,
                "=",
                ab.cases.map { expressionOf(it.case) to expressionOf(it.from) },
                expressionOf(ab.default)
            )
        }
        sb.append("    end\n")
        return sb.toString()
    }

    /**
     * Renders [lu] as an edge-triggered `always` block using non-blocking (`<=`) assignments.
     *
     * The sensitivity list joins every trigger with `or`; a trigger that is not a
     * [SeqUnit.PosEdge] is emitted as `negedge`.
     */
    fun compileSeqUnit(lu: SeqUnit): String {
        val sb = StringBuilder()
        val trig = lu.trigCases
            .map { if (it is SeqUnit.PosEdge) "posedge ${expressionOf(it.signal)}" else "negedge ${expressionOf(it.signal)}" }
            .reduce { s1, s2 -> "$s1 or $s2" }
        sb.append("    always @($trig) begin\n")
        lu.assignBlocks.forEach { ab ->
            sb.appendGuardedAssign(
                ab.target.name,
                "<=",
                ab.cases.map { expressionOf(it.case).removeBracket() to expressionOf(it.from) },
                expressionOf(ab.default)
            )
        }
        sb.append("    end\n")
        return sb.toString()
    }

    /**
     * Appends the assignment(s) for one target inside an `always` block.
     *
     * @param target name of the assigned signal.
     * @param op assignment operator text (`=` or `<=`).
     * @param branches (condition, value) pairs in priority order; may be empty.
     * @param fallback value assigned when no condition holds.
     */
    private fun StringBuilder.appendGuardedAssign(
        target: String,
        op: String,
        branches: List<Pair<String, String>>,
        fallback: String
    ) {
        if (branches.isEmpty()) {
            append("        $target $op $fallback;\n")
            return
        }
        branches.forEachIndexed { index, (condition, value) ->
            val keyword = if (index == 0) "if" else "else if"
            append("        $keyword ($condition)\n")
            append("            $target $op $value;\n")
        }
        // The default branch is emitted exactly once, after the last condition.
        append("        else\n")
        append("            $target $op $fallback;\n")
    }

    /**
     * Renders [signal] as a Verilog expression.
     *
     * Constants become sized literals, named signals their name, and expression wires the
     * operator expression of their logic unit. Operator expressions are wrapped in
     * parentheses; callers strip the outermost pair with [removeBracket] where it is redundant.
     */
    fun expressionOf(signal: Signal): String {
        return when (signal) {
            is ConstSignal -> when (signal.format) {
                ConstSignal.Hex -> "${signal.width}'h${String.format("%x", signal.value)}"
                ConstSignal.Dec -> "${signal.width}'d${String.format("%d", signal.value)}"
            }
            is NamedSignal -> signal.name
            is ExpressionWire -> when (val lu = signal.logicUnit) {
                // For And/Or/Xor a nested operand of the same operator is flattened:
                // its own parentheses are dropped.
                is And -> "(${lu.inputSignals.map {
                    expressionOf(it).let { exp ->
                        if (it is ExpressionWire && it.logicUnit is And) exp.removeBracket() else exp
                    }
                }.reduce { s1, s2 -> "$s1 & $s2" }})"
                is Inv -> "(~${expressionOf(lu.input)})"
                is Or -> "(${lu.inputSignals.map {
                    expressionOf(it).let { exp ->
                        if (it is ExpressionWire && it.logicUnit is Or) exp.removeBracket() else exp
                    }
                }.reduce { s1, s2 -> "$s1 | $s2" }})"
                is Slice -> {
                    if (lu.sliceWidth == 1) {
                        "${lu.input.name}[${lu.start}]"
                    } else {
                        "${lu.input.name}[${lu.start + lu.sliceWidth - 1}:${lu.start}]"
                    }
                }
                is Xor -> "(${lu.inputSignals.map {
                    expressionOf(it).let { exp ->
                        if (it is ExpressionWire && it.logicUnit is Xor) exp.removeBracket() else exp
                    }
                }.reduce { s1, s2 -> "$s1 ^ $s2" }})"
                is Repeat -> {
                    val operand: Signal = lu.input
                    val rendered = expressionOf(operand)
                    // Only a concatenation may lose its braces here: "{a,b}" -> "{n{a,b}}".
                    // A nested replication must keep them, otherwise "{2{a}}" would become
                    // the invalid "{n{2{a}}}" instead of "{n{{2{a}}}}".
                    val inner = if (operand is ExpressionWire && operand.logicUnit is Joint) {
                        rendered.removeBracket('{', '}').removeBracket()
                    } else {
                        rendered.removeBracket()
                    }
                    "{${lu.repeat}{$inner}}"
                }
                is Joint -> "{${lu.inputSignals.map { expressionOf(it).removeBracket() }.reduce { s1, s2 -> "$s1,$s2" }}}"
                is SAnd -> "(&${expressionOf(lu.input)})"
                is SOr -> "(|${expressionOf(lu.input)})"
                is SXor -> "(^${expressionOf(lu.input)})"
                is Plus -> "(${expressionOf(lu.inputs.first)} + ${expressionOf(lu.inputs.second)})"
                is Equal -> "(${expressionOf(lu.inputs.first)} == ${expressionOf(lu.inputs.second)})"
                is LessThan -> "(${expressionOf(lu.inputs.first)} < ${expressionOf(lu.inputs.second)})"
                is MoreThan -> "(${expressionOf(lu.inputs.first)} > ${expressionOf(lu.inputs.second)})"
            }
        }
    }

    /**
     * Removes the outermost [open]/[close] pair when it encloses the whole string.
     *
     * The string is returned unchanged when it does not start with [open] and end with
     * [close], or when the first bracket is closed before the final character
     * (for example `"(a)&(b)"`).
     */
    fun String.removeBracket(open: Char = '(', close: Char = ')'): String {
        if ((this.firstOrNull() != open) || (this.lastOrNull() != close)) {
            return this
        }
        var depth = 0
        for (i in indices) {
            val c = this[i]
            if (c == open) {
                depth++
            } else if (c == close) {
                depth--
            }
            // Depth reaching zero before the end means the leading bracket does not
            // pair with the trailing one.
            if (i != lastIndex && depth == 0) {
                return this
            }
        }
        return if (depth == 0) substring(1, length - 1) else this
    }
}